{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "c2218b66",
   "metadata": {},
   "source": [
    "# Deep Dive into Text-Image Search Engine with Towhee\n",
    "\n",
    "In the [previous tutorial](./1_build_text_image_search_engine.ipynb), we built and prototyped a proof-of-concept image search engine. Now, let's finetune it with our own image datasets, and deploy it with accleration service."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ae6b056f",
   "metadata": {},
   "source": [
    "## Finetune Text-Image Search on Custom Dataset\n",
    "\n",
    "### Install Dependencies\n",
    "\n",
    "Firstly, we need to install dependencies such as towhee and opencv-python. And please make sure that you have started a [Milvus service](https://milvus.io/docs/install_standalone-docker.md). This notebook uses [milvus 2.2.10](https://milvus.io/docs/v2.2.x/install_standalone-docker.md) and [pymilvus 2.2.11](https://milvus.io/docs/release_notes.md#2210)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bca1652c",
   "metadata": {},
   "outputs": [],
   "source": [
    "! python -m pip -q install towhee opencv-python pymilvus==2.2.11"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ba622fa5",
   "metadata": {},
   "source": [
    "### Prepare the Data\n",
    "\n",
    "For text-image search, we use CIFAR-10 dataset as an example to show how to finetune CLIP model for users' customized dataset. CIFAR-10 dataset contains 60,000 32x32 color images in 10 different classes. It is widely used as an image recognition benchmark for various computer vision models. In this example, we manually create the caption by creating the sentence with its corresponding label.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "5152da61",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "import torchvision\n",
    "import os\n",
    "import json\n",
    "\n",
    "\n",
    "root_dir = '/tmp/'\n",
    "train_dataset = torchvision.datasets.CIFAR10(root=root_dir, train=True, download=True)\n",
    "eval_dataset = torchvision.datasets.CIFAR10(root=root_dir, train=False, download=True)\n",
    "\n",
    "\n",
    "idx = 0\n",
    "def build_image_text_dataset(root, folder, dataset):\n",
    "    results = []\n",
    "    global idx\n",
    "    labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n",
    "    if not os.path.exists(os.path.join(root,folder)):\n",
    "        os.mkdir(os.path.join(root,folder))\n",
    "    for img, label_idx in dataset:\n",
    "        item  = {}\n",
    "        imgname = \"IMG{:06d}.png\".format(idx)\n",
    "        filename = os.path.join(root, folder, imgname)\n",
    "        idx = idx + 1\n",
    "        caption = 'this is a picture of {}.'.format(labels[label_idx])\n",
    "        img.save(filename)\n",
    "        item['caption_id'] = idx\n",
    "        item['image_id'] = idx\n",
    "        item['caption'] = caption\n",
    "        item['image_path'] = filename\n",
    "        results.append(item)\n",
    "    return results\n",
    "\n",
    "def gen_caption_meta(root, name, meta):\n",
    "    save_path = os.path.join(root, name+'.json')\n",
    "    with open(save_path, 'w') as fw:\n",
    "        fw.write(json.dumps(meta, indent=4))\n",
    "\n",
    "train_results = build_image_text_dataset(root_dir, 'train', train_dataset)\n",
    "gen_caption_meta(root_dir, 'train', train_results)\n",
    "\n",
    "eval_results = build_image_text_dataset(root_dir, 'eval', eval_dataset)\n",
    "gen_caption_meta(root_dir, 'eval', eval_results)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "2b111adf",
   "metadata": {},
   "source": [
    "Now we have an image-text annotation of CIFAR-10\n",
    "\n",
    "|caption ID|image ID | caption   |  image  | image path|\n",
    "|:--------|:-------- |:----------|:--------|:----------|\n",
    "| 0 | 0 | this is a picture of frog.|  <img src=\"train/IMG000000.png\" max-width=\"50\" width=\"50\" height=\"50\">| /tmp/train/IMG000000.png |\n",
    "| 1 | 1 | this is a picture of truck. |  <img src=\"train/IMG000001.png\" max-width=\"50\" width=\"50\" height=\"50\">| /tmp/train/IMG000001.png |\n",
    "| 2 | 2 | this is a picture of truck. |  <img src=\"train/IMG000002.png\" max-width=\"50\" width=\"50\" height=\"50\">| /tmp/train/IMG000002.png  |\n",
    "| 3 | 3 | this is a picture of deer.|  <img src=\"train/IMG000003.png\" max-width=\"50\" width=\"50\" height=\"50\">| /tmp/train/IMG000003.png  |\n",
    "| 4 | 4 | this is a picture of automobile.|  <img src=\"train/IMG000004.png\" max-width=\"50\" width=\"50\" height=\"50\">| /tmp/train/IMG000004.png  |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e08849b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import towhee\n",
    "from towhee import ops\n",
    "#step1\n",
    "#get the operator, modality has no effect to the training model, it is only for the inference branch selection.\n",
    "clip_op = ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image').get_op()\n",
    "\n",
    "\n",
    "#step2\n",
    "#trainer configuration, theses parameters are huggingface-style standard training configuration.\n",
    "data_args = {\n",
    "    'dataset_name': None,\n",
    "    'dataset_config_name': None,\n",
    "    'train_file': '/tmp/train.json',\n",
    "    'validation_file': '/tmp/eval.json',\n",
    "    'max_seq_length': 77,\n",
    "    'data_dir': None,\n",
    "    'image_mean': [0.48145466, 0.4578275, 0.40821073],\n",
    "    \"image_std\": [0.26862954, 0.26130258, 0.27577711]\n",
    "}\n",
    "\n",
    "training_args = {\n",
    "    'num_train_epochs': 32, # you can add epoch number to get a better metric.\n",
    "    'per_device_train_batch_size': 64,\n",
    "    'per_device_eval_batch_size': 64,\n",
    "    'do_train': True,\n",
    "    'do_eval': True,\n",
    "    'eval_steps':1,\n",
    "    'remove_unused_columns': False,\n",
    "    'dataloader_drop_last': True,\n",
    "    'output_dir': './output/train_clip_exp',\n",
    "    'overwrite_output_dir': True,\n",
    "}\n",
    "\n",
    "model_args = {\n",
    "    'freeze_vision_model': False,\n",
    "    'freeze_text_model': False,\n",
    "    'cache_dir': './cache'\n",
    "}\n",
    "\n",
    "#step3\n",
    "#train your model\n",
    "clip_op.train(data_args=data_args, training_args=training_args, model_args=model_args)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "9994a9f4",
   "metadata": {},
   "source": [
    "CLIP operator uses standard Hugging Face [Transformers](https://github.com/huggingface/transformers) training procedure to finetune the model. The detail of training configuration can be found at [transformers doc](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments).\n",
    "When training procedure is finished, we can load the trained weights into the operator."
   ]
  },
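  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "7d3f1a2c",
   "metadata": {},
   "source": [
    "The checkpoint path passed to the operator depends on how many steps were trained. As a minimal sketch, assuming the Hugging Face Trainer's default `checkpoint-<step>` directory naming inside `output_dir`, the hypothetical helper `latest_checkpoint` below (not part of Towhee) returns the checkpoint directory with the highest step number.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import re\n",
    "\n",
    "\n",
    "def latest_checkpoint(output_dir):\n",
    "    # Return the checkpoint-<step> subdirectory with the largest step, or None.\n",
    "    if not os.path.isdir(output_dir):\n",
    "        return None\n",
    "    best_step, best_path = -1, None\n",
    "    for name in os.listdir(output_dir):\n",
    "        match = re.fullmatch('checkpoint-([0-9]+)', name)\n",
    "        path = os.path.join(output_dir, name)\n",
    "        if match and os.path.isdir(path) and int(match.group(1)) > best_step:\n",
    "            best_step, best_path = int(match.group(1)), path\n",
    "    return best_path\n",
    "```\n",
    "\n",
    "For example, `os.path.join(latest_checkpoint('./output/train_clip_exp'), 'pytorch_model.bin')` could replace a hard-coded `checkpoint_path`, assuming at least one checkpoint was saved."
   ]
  },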
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6afd13cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image', checkpoint_path='./output/train_clip_exp/checkpoint-5000/pytorch_model.bin')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "414cfe20",
   "metadata": {},
   "source": [
    "## Making Our Text-Image Search Pipeline Production Ready\n",
    "\n",
    "The text-image pipeline now can be finetuned on customized dataset to get the gain from specific dataset. To put the text-image search engine into production, we also need to execute the whole pipeline in a highly-efficient way instead  of original PyTorch execution.\n",
    "\n",
    "Towhee supports NVIDIA Triton Inference Server to improve performance for inferencing data for production-ready services. The supported model can be transfered to a Triton service just in a few lines.\n",
    "\n",
    "Operators can be packed into a Triton service for better inferencing performance. Some specific models of operator can be exported to ONNX models and achieve better acceleration (default is TorchScript).\n",
    "\n",
    "Before getting started, please make sure you have built `text_image_search` collection that uses the [L2 distance metric](https://milvus.io/docs/v2.0.x/metric.md#Euclidean-distance-L2) and an [IVF_FLAT index](https://milvus.io/docs/v2.0.x/index.md#IVF_FLAT) as the [previous tutorial](./1_build_text_image_search_engine.ipynb).\n",
    "\n",
    "### Check Operator \n",
    "Firstly, we need to check if the operator can be transfered to ONNX."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "20bb99a0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "full model list: ['clip_vit_base_patch16', 'clip_vit_base_patch32', 'clip_vit_large_patch14', 'clip_vit_large_patch14_336']\n",
      "onnx model list: ['clip_vit_base_patch16', 'clip_vit_base_patch32', 'clip_vit_large_patch14', 'clip_vit_large_patch14_336']\n"
     ]
    }
   ],
   "source": [
    "from towhee import ops, pipe\n",
    "import numpy as np\n",
    "\n",
    "op = ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='image').get_op()\n",
    "full_list = op.supported_model_names()\n",
    "onnx_list = op.supported_model_names(format='onnx')\n",
    "\n",
    "print('full model list:', full_list)\n",
    "print('onnx model list:', onnx_list)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8e8f12cf",
   "metadata": {},
   "source": [
    "All candidate models of CLIP can be transfered to ONNX model for the Triton pipeline acceleration.\n",
    "\n",
    "### Build Docker Service"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81f1f0f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "op = ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text').get_op()\n",
    "\n",
    "#your host machine IP address, e.g. 192.158.1.38\n",
    "ip_addr = '192.158.1.38'\n",
    "\n",
    "#make sure you have built Milvus collection successfully.\n",
    "p_search = (\n",
    "    pipe.input('text')\n",
    "        .map('text', 'vec', ops.image_text_embedding.clip(model_name='clip_vit_base_patch16', modality='text'), config={'device': 0})\n",
    "        .map('vec', 'vec', lambda x: x / np.linalg.norm(x))\n",
    "        .map('vec', ('search_res'), ops.ann_search.milvus_client(host=ip_addr, port='19530', limit=5, collection_name=\"text_image_search\", output_fields=['url']))\n",
    "        .output('text','search_res')\n",
    ")\n",
    "\n",
    "towhee.build_docker_image(\n",
    "    dc_pipeline=p_search,\n",
    "    image_name='text_image_search:v1',\n",
    "    cuda_version='11.7', # '117dev' for developer\n",
    "    format_priority=['onnx'],\n",
    "    inference_server='triton'\n",
    ")\n"
   ]
  },
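  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "5b8e4c19",
   "metadata": {},
   "source": [
    "The pipeline divides the text embedding by its L2 norm before searching. For unit-length vectors, `||a - b||^2 = 2 - 2 * cos(a, b)`, so ranking by L2 distance gives the same order as ranking by cosine similarity, provided the vectors stored in the collection were normalized as well (an assumption about how the collection was built). The sketch below checks the identity with random vectors standing in for CLIP embeddings; the dimension 512 is only illustrative.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "a = rng.normal(size=512)\n",
    "b = rng.normal(size=512)\n",
    "\n",
    "# L2-normalize both vectors, as the lambda in the pipeline does.\n",
    "a = a / np.linalg.norm(a)\n",
    "b = b / np.linalg.norm(b)\n",
    "\n",
    "cosine = float(np.dot(a, b))\n",
    "l2 = float(np.linalg.norm(a - b))\n",
    "\n",
    "# For unit vectors, the squared L2 distance equals 2 - 2 * cosine similarity.\n",
    "print(np.isclose(l2 ** 2, 2 - 2 * cosine))\n",
    "```"
   ]
  },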
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "226f16ba",
   "metadata": {},
   "source": [
    "After the docker image is built, the inferencing service and its associated model is resident in it. Start the service by running a docker container."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "a27395a2",
   "metadata": {},
   "source": [
    "```console\n",
    "docker run -td --gpus=all --shm-size=1g \\\n",
    "    -p 8000:8000 -p 8001:8001 -p 8002:8002 \\\n",
    "    text_image_search:v1 \\\n",
    "    tritonserver --model-repository=/workspace/models\n",
    "```"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "4e3bafa9",
   "metadata": {},
   "source": [
    "### Inference with Triton Service\n",
    "Now we can use a client to visit the accelerated service."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "c1be1ef7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "idx: 96, distance_score:1.35 , path: ./train/Bouvier_des_Flandres/n02106382_8906.JPEG\n",
      "idx: 506, distance_score:1.38 , path: ./train/Doberman/n02107142_4753.JPEG\n",
      "idx: 835, distance_score:1.38 , path: ./train/Afghan_hound/n02088094_3882.JPEG\n",
      "idx: 507, distance_score:1.39 , path: ./train/Doberman/n02107142_32921.JPEG\n",
      "idx: 832, distance_score:1.39 , path: ./train/Afghan_hound/n02088094_6565.JPEG\n"
     ]
    }
   ],
   "source": [
    "from towhee import triton_client\n",
    "\n",
    "client = triton_client.Client(url='localhost:8000')\n",
    "\n",
    "data = \"a black dog.\"\n",
    "res = client(data)\n",
    "\n",
    "for idx, dis_score, path in res[0][1]:\n",
    "    print('idx: {}, distance_score:{:.2f} , path: {}'.format(idx, dis_score, path))\n",
    "client.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
